Abstract

In this report we examine the study “Distributed coding of choice, action, and engagement across the mouse brain” by Steinmetz et al. (2019). In the study, the neural activity of mice was recorded with Neuropixels probes while the mice performed a task involving visual stimuli. This report uses only a subset of the data: 18 of the original 39 sessions, from 4 of the 10 mice. We begin with exploratory data analysis to visualize neural activity and identify trends across variables. We then integrate the data to combine trials across sessions and mice. Finally, we build a machine learning model to predict whether a mouse succeeds or fails on a trial and evaluate its performance on test data.

Introduction

The study “Distributed coding of choice, action, and engagement across the mouse brain” by Steinmetz et al. (2019) focuses on how neural signals across the brain encode vision, choice, and action. Using Neuropixels probes, the authors aimed to map these signals across different brain regions. Experiments were performed on 10 mice over 39 sessions, with hundreds of trials in each session. The task was simple. Mice were placed in front of two screens, one on their left and one on their right. Visual stimuli were presented on these screens at contrast levels of {0, 0.25, 0.5, 1}, with 0 indicating the absence of a stimulus. In front of the mice was also a wheel that they could turn. Success and failure in the task were defined as:

  • When left contrast > right contrast, success (1) if turning the wheel to the right and failure (-1) otherwise.
  • When right contrast > left contrast, success (1) if turning the wheel to the left and failure (-1) otherwise.
  • When both left and right contrasts are zero, success (1) if holding the wheel still for 1.5 seconds and failure (-1) otherwise.
  • When left and right contrasts are equal but non-zero, left or right will be randomly chosen (50%) as the correct choice.
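The four rules above can be restated as a small R function, which makes the mapping from contrasts to the correct response explicit. This is an illustrative sketch only: `correct_response` is a hypothetical helper, not part of the study's data or code, and for equal non-zero contrasts it returns a label rather than simulating the random 50% choice.

```r
# Hypothetical helper encoding the task rules: given the left and right
# contrasts of each trial, return the response that counts as a success.
correct_response <- function(contrast_left, contrast_right) {
  ifelse(contrast_left > contrast_right, "turn right",
    ifelse(contrast_right > contrast_left, "turn left",
      # contrasts are equal from here on
      ifelse(contrast_left == 0, "hold still",
        "random (50% left / 50% right)")))
}

correct_response(c(1, 0, 0, 0.5), c(0.25, 0.5, 0, 0.5))
# "turn right" "turn left" "hold still" "random (50% left / 50% right)"
```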

If the mouse succeeded in the task, it was rewarded. Neuron activity was recorded as spike trains, that is, the times at which each neuron fired. In this report we examine what other variables were collected, the relationships between them, how we can integrate the data, and how best to build a predictive model. We begin by examining frequency tables and the structure of the different variables. We then integrate the data across the 18 sessions and 4 mice to make plotting easier. In the EDA, we plot line graphs of spike activity over time split by brain area, and then aggregate over brain areas and split by trial outcome (success or failure). Interesting patterns emerge in these plots. We also plot bar graphs of success rate by contrast relationship to see whether other patterns emerge. In the predictive modeling section, we start with basic models such as logistic regression and k-nearest neighbors and build on them using different combinations of predictor variables, LASSO, interaction terms, XGBoost, and an ensemble model. Finally, we test our optimized model on a test data set to evaluate the accuracy of the final model.

EDA- Tables

In this section we start with some basic statistical exploration to better understand the data. Let’s first look at the output of summary() for session 1, which lists the length, class, and mode of each variable.

Basics

The data set we are working with is a set of 18 sessions where each session is a collection of trials.
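From the summary we see there are 8 variables in each session: contrast_left, contrast_right, feedback_type, mouse_name, brain_area, date_exp, spks, and time. contrast_left, contrast_right, and feedback_type are numeric, with one value per trial (114 trials in session 1). mouse_name, brain_area, and date_exp are categorical: mouse_name and date_exp hold a single value per session, while brain_area has one entry per recorded neuron (734 neurons in session 1). spks and time are lists with one element per trial, but their elements are not single numbers: each element of spks is a matrix of spike counts with one row per neuron and one column per time bin, and the matching element of time gives the time points of those bins. This nested structure can present some challenges in the code later, so it is important to notice it now.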

##                Length Class  Mode     
## contrast_left  114    -none- numeric  
## contrast_right 114    -none- numeric  
## feedback_type  114    -none- numeric  
## mouse_name       1    -none- character
## brain_area     734    -none- character
## date_exp         1    -none- character
## spks           114    -none- list     
## time           114    -none- list
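To make the nested structure concrete, here is a toy session built to mimic it. Everything in it is made up for illustration (the name `toy_session`, the sizes, the bin spacing of 0.01, and all values); only the field names and the neurons-by-time-bins layout of each spks matrix follow the description above, and date_exp is omitted.

```r
# Toy session mimicking the structure described above: 3 trials,
# 5 neurons, 40 time bins (all sizes and values are invented)
set.seed(1)
n_trials <- 3; n_neurons <- 5; n_bins <- 40
toy_session <- list(
  contrast_left  = c(0, 0.5, 1),          # one value per trial
  contrast_right = c(0.25, 0, 1),
  feedback_type  = c(1, -1, 1),
  mouse_name     = "Cori",                # one value per session
  brain_area     = c("CA3", "CA3", "DG", "root", "root"),  # one per neuron
  # one neurons x time-bins matrix of spike counts per trial
  spks = replicate(n_trials,
                   matrix(rpois(n_neurons * n_bins, 0.05), nrow = n_neurons),
                   simplify = FALSE),
  # one vector of bin time points per trial
  time = replicate(n_trials, seq(0, by = 0.01, length.out = n_bins),
                   simplify = FALSE)
)

dim(toy_session$spks[[1]])      # 5 40  (neurons, time bins)
length(toy_session$time[[1]])   # 40    (one time point per bin)
toy_session$spks[[2]][4, 10]    # spikes of neuron 4 in bin 10 of trial 2
```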

Frequency Tables

To explore the variables further, we construct a frequency table for each one to see how its values are distributed across sessions.

contrast_left

It looks like contrast_left has four levels: 0, 0.25, 0.5 and 1. These represent the contrast level of the stimuli presented to the mouse on the left screen.

contrast_right

It looks like contrast_right has four levels: 0, 0.25, 0.5 and 1. These represent the contrast level of the stimuli presented to the mouse on the right screen.

feedback_type

It looks like feedback_type has 2 levels: -1 and 1, where 1 is a success (the mouse completed the task correctly) and -1 is a failure. It might be useful later to recode this as binary 0/1 data, which is easier to work with.

mouse_name

It looks like there are 4 different mice across these sessions: Cori, Forssmann, Hench and Lederberg. Each mouse has a few sessions with their data:
Cori: Sessions 1-3
Forssmann: Sessions 4-7
Hench: Sessions 8-11
Lederberg: Sessions 12-18

brain_area

The brain areas in which spikes were recorded differ from session to session, so the sessions do not share a common set of brain_area levels. Integrating this variable across sessions may therefore be difficult, and it might be best to drop it altogether if we cannot find a sensible way to combine it.

date_exp

There is just one date per session. This variable does not seem very relevant: all the sessions were conducted within a year, so there is not enough variation for a pattern to emerge. We will ignore it from now on.

Data Integration

Integrating the data and engineering features will make the data easier to manipulate. The overall structure is many trials within each session. We need to decide whether to add new columns to the data frame or to reorganize the data so that each row is a trial. We also know that spks and time have a difficult structure (lists of per-trial matrices and vectors), so we need a way to summarize them. Finally, contrast_left and contrast_right have a meaningful relationship within the context of the study, so it is worthwhile to create a new variable that captures it. We do all of this in the following section.

Unique Counts

One thing we can look at is the number of unique brain areas, neurons, and trials in each session, to get a better sense of what each session contains. We have also included each session’s success rate in the following data table:

Averaging Spikes Count

We mentioned earlier that spks will be hard to work with. Before visualizing the data, we need to decide how to handle it. One simple approach is to summarize each trial by its average spike count: sum each neuron’s spikes over the time bins, then average across the neurons recorded in that session.
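The per-trial summary described above can be sketched in a few lines. `trial_avg_spikes` is a hypothetical helper, and the two trials are toy 2-neuron by 3-bin matrices rather than real data.

```r
# Hypothetical helper: one number per trial, the average number of spikes
# a neuron fires during that trial (sum over time bins, then mean over neurons)
trial_avg_spikes <- function(spks) {
  vapply(spks, function(m) mean(rowSums(m)), numeric(1))
}

# Two toy trials, each a 2-neuron x 3-bin matrix of spike counts
toy_spks <- list(
  matrix(c(1, 0, 2, 1, 0, 0), nrow = 2),  # neuron totals: 3, 1
  matrix(c(0, 0, 1, 3, 1, 1), nrow = 2)   # neuron totals: 2, 4
)
trial_avg_spikes(toy_spks)  # 2 3
```

Using mean(m) instead of mean(rowSums(m)) would give the same ranking of trials scaled down by the number of bins (the per-bin version used by the feature extraction in the appendix); either works as long as it is used consistently.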

Calculating Contrast Relationship

We are given contrast_left and contrast_right, and the study explains how these variables relate to feedback_type. We can create 3 new variables that make this relationship easier to examine: the contrast relationship, the absolute contrast difference, and the success rate.
Contrast relationship is a categorical variable and a more general measure. Contrast difference is numerical, and therefore more quantitative and specific. Success rate is the proportion of trials in a session with feedback_type = 1 (success), which tells us how often we can expect a success.
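A minimal sketch of the three derived variables, mirroring the case_when() categories used in the appendix. `derive_contrast_features` is a hypothetical name, and the inputs are toy values with feedback_type coded -1/1. Note that mean(feedback_type == 1) equals the appendix's (mean(feedback_type) + 1) / 2 for -1/1 coding.

```r
# Hypothetical helper: derive contrast_relationship, contrast_diff, and the
# session-level success_rate from one session's trial vectors
derive_contrast_features <- function(contrast_left, contrast_right, feedback_type) {
  data.frame(
    contrast_relationship = ifelse(contrast_left > contrast_right, "Left > Right",
      ifelse(contrast_left < contrast_right, "Right > Left",
        ifelse(contrast_left == 0, "Both Zero", "Equal Non-Zero"))),
    contrast_diff = abs(contrast_left - contrast_right),
    # share of trials with feedback_type == 1, repeated on every row
    success_rate = mean(feedback_type == 1),
    stringsAsFactors = FALSE
  )
}

derive_contrast_features(c(1, 0.25, 0, 0.5), c(0, 0.5, 0, 0.5), c(1, -1, 1, -1))
# one row per trial; success_rate is 0.5 in every row
```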

EDA- Plots

Linegraphs by Session

We can now start generating plots to explore relationships among variables. We first plot mean spike activity by brain_area, since these variables likely show interesting patterns. We plot first by session and then by mouse_name. This will be a common pattern in our EDA, as we want to find out whether organizing by session or by mouse_name (each mouse has multiple sessions) will help us build a better predictive model.
The following three plots are just of the first 3 sessions which are all Cori’s data.

Session 1

In this plot, CA3 and DG stand out, with mean spike counts jumping to 0.07. However, there does not appear to be much correlation or overlap between activity in separate brain areas.

Session 2

In Session 2 there is a lot of overlap between the spike activity of different brain areas. Fewer, and different, brain areas were recorded, but they seem to follow the same pattern. Perhaps by Session 2 Cori has begun to learn what to expect from the task. The maximum mean spike count is also lower, with only root reaching 0.05.

Session 3

By Session 3 we might expect Cori to have learned the task and to show a lower mean spike count across the graph. However, a newly recorded brain area, LP, shows very high mean spike counts over time, even reaching 0.3. It may be better to aggregate all of Cori’s sessions and look for meaningful patterns there, which is what we do in the next section.

Linegraphs by Mouse

We can try splitting the data by mouse_name instead, which should provide more insight. Here is the breakdown again:
Cori: Sessions 1-3
Forssmann: Sessions 4-7
Hench: Sessions 8-11
Lederberg: Sessions 12-18
These plots display the mean spike activity by brain_area over time for each individual mouse. A nice feature of these plots is that the spike activity is layered, so we can track differences over time. A disadvantage of plotting by session is that a different set of brain areas is recorded in each session: an area may appear in one session’s plot, or in two, three, or more, which makes its activity hard to compare across sessions. Combining data across sessions that share the same mouse_name addresses this.

Cori 1

In this plot we see LP as an obvious outlier reaching very high mean spike activity. There seems to be too much noise going on in this plot with all the overlap, and it’s hard to find meaningful insights. We can try separating the brain areas into different plots to see if this helps, which we do in the next tab.

Cori 2

This graph places each brain_area in its own panel, and several observations stand out. First, root is the only brain_area common to all three sessions, although a few other brain areas are shared by two sessions. Also, certain brain areas, such as VISl, VISpm, and SPF, have similar spike patterns.


Lederberg 1

We can do the same analysis for Lederberg. The plot of all of Lederberg’s sessions contains a very large amount of data and is hard to read. Some brain areas, such as ACA and RN, have noticeably higher mean spike counts. As we did for Cori, we can separate the brain areas to see whether this improves the visualization, which we do in the next tab.

Lederberg 2

While this approach helped us see patterns for Cori, Lederberg has too many sessions and too many brain areas, and this graph gives few meaningful insights. We do see that the brain areas LD, MD, and RSP have similar spike fluctuations, though.


Linegraphs: Spike Count by Success

Including every brain_area makes the plots too hard to read. We also ultimately want to investigate whether there is a relationship between the success of a trial (feedback_type) and the brain activity (spks) of the mouse. We therefore filter out brain areas with little activity, using a threshold value of 5: brain areas whose mean spike count is not above the threshold are excluded. We then take the mean spike count across time and split it into 2 groups: spikes recorded in trials with feedback_type = 1 and spikes recorded in trials with feedback_type = -1. This lets us check whether any patterns in spks distinguish successful trials from failed trials, which can inform the prediction model later.
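The filter-then-split procedure above can be sketched as follows. This is not the appendix code (which is truncated here); it assumes that an area passes the filter when its total spikes per trial, averaged over trials, exceed the threshold, and `mean_spikes_by_outcome` is a hypothetical name. The toy data use two trials, three neurons, and two time bins.

```r
# Hypothetical sketch: mean spike count per time bin for success vs failure
# trials, keeping only brain areas whose activity passes a threshold.
# Assumption: an area passes if its total spikes per trial, averaged over
# trials, exceed `threshold`.
mean_spikes_by_outcome <- function(spks, brain_area, feedback_type, threshold = 5) {
  # Areas x trials matrix: total spikes fired by each brain area in each trial
  area_totals <- do.call(cbind, lapply(spks, function(m) rowsum(rowSums(m), brain_area)))
  keep_areas <- rownames(area_totals)[rowMeans(area_totals) > threshold]
  keep <- brain_area %in% keep_areas

  # Trials x time bins: mean over the kept neurons in each bin
  bin_means <- do.call(rbind, lapply(spks, function(m) colMeans(m[keep, , drop = FALSE])))

  data.frame(
    time_bin = seq_len(ncol(bin_means)),
    success  = colMeans(bin_means[feedback_type == 1, , drop = FALSE]),
    failure  = colMeans(bin_means[feedback_type == -1, , drop = FALSE])
  )
}

trial1 <- matrix(c(3, 2, 0, 4, 1, 0), nrow = 3)  # 3 neurons x 2 bins
trial2 <- matrix(c(1, 1, 1, 2, 2, 1), nrow = 3)
res <- mean_spikes_by_outcome(list(trial1, trial2), c("A", "A", "B"), c(1, -1))
res  # area B is dropped; success: 2.5 2.5, failure: 1 2
```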

Cori

We can instantly see a pattern in Cori’s plot. The average spike counts are similar until time bin 12, where the two lines abruptly separate. After time bin 12 the two lines fluctuate in a similar way, but the mean spike count for success trials is notably higher than for failure trials.

Forssmann

We see something different with Forssmann. It is not until time bin = 16 that the success line is higher than the failure line. After that we see a much higher success line but still with similar fluctuations as the failure line, just at a higher mean spike count.

Hench

Similarly to Forssmann, it is not until time bin 13 that the success line crosses above the failure line. Before time bin 13, the mean spike count for success trials is actually lower than for failure trials; after it, the success mean spike counts grow. However, unlike Cori and Forssmann, the success line, while higher, does not fluctuate in step with the failure line. It mirrors it instead: when the success line dips, the failure line rises, and vice versa.

Lederberg

Here we find something interesting: for Lederberg, the mean spike counts of the success and failure trials do not overlap at all. As with Cori, the two lines fluctuate in the same way, but the success line has much higher mean spike counts.

Bargraphs: Success Rate by Contrast Relationship

Another interesting relationship to explore is the one between feedback_type and different measures of the stimuli (contrast levels).
Keep in mind this aspect of the study:
- When left contrast > right contrast, success (1) if turning the wheel to the right and failure (-1) otherwise.
- When right contrast > left contrast, success (1) if turning the wheel to the left and failure (-1) otherwise.
- When both left and right contrasts are zero, success (1) if holding the wheel still and failure (-1) otherwise.
- When left and right contrasts are equal but non-zero, left or right will be randomly chosen (50%) as the correct choice.

As with the linegraphs we can split the data first by sessions and then by mouse to see if there are any insights. We will first look at contrast_relationship as a measure for the contrast levels.

Across Sessions

As expected the Equal Non-Zero category hovers around 50% because the correct choice in the study is randomly given. The Left > Right and Right > Left bins seem to have an overall higher success rate than the Both Zero. Maybe this was harder for the mice because succeeding in the Both Zero category meant no action from the mice whereas the Left > Right and Right > Left bins required an action from the mice. The mice might have found turning the wheel easier than not turning it.

Across Mice

Lederberg seems to be leading across three of the four bins. Forssmann also has high success rate across categories.

Predictive Modeling

Model Exploration

We start by building some basic models as a baseline: a logistic regression (GLM), a k-nearest-neighbor (k-NN) classifier, and a simple model. The GLM and k-NN models use only numeric predictors, contrast_left, contrast_right, and average_spks, to predict the binary variable feedback_type. The simple model uses only success_rate to predict feedback_type: if a session’s success_rate is over 50%, it predicts success for the trial, and if it is under 50%, it predicts failure.

Surprisingly, the simple model had the highest accuracy at 71%, followed by logistic regression at 69.7%, while k-NN performed worst at 66.3%. We should be able to do better than the simple model, so our goal is a model with over 71% accuracy. Since we also want to add mouse_name, a categorical variable, to our model, k-NN may not be the best fit because it works only with numeric predictors. From now on we focus on the GLM and try to improve its performance.

## Logistic Regression Accuracy: 0.697
## k-NN Accuracy: 0.663
## Simple Model Accuracy (Based on Success Rate): 0.7101
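The simple baseline can be written as a one-line rule plus an accuracy calculation. This sketch assumes feedback_type coded -1/1, and it predicts failure when the success rate is exactly 0.5, a case the description above leaves open. `simple_model_accuracy` and the toy vectors are hypothetical.

```r
# Hypothetical sketch of the baseline: predict success (1) for every trial in a
# session whose success rate is above 0.5, failure (-1) otherwise, then score it
simple_model_accuracy <- function(success_rate, feedback_type) {
  predicted <- ifelse(success_rate > 0.5, 1, -1)
  mean(predicted == feedback_type)
}

# Toy data: two sessions with success rates 0.75 and 0.4
sr <- c(0.75, 0.75, 0.75, 0.75, 0.4, 0.4)
fb <- c(1, 1, -1, 1, -1, 1)
simple_model_accuracy(sr, fb)  # 0.6666667 (4 of 6 trials correct)
```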

Choosing Predictor Variables

When we plotted spks across time for the selected brain areas and split by feedback_type, an obvious pattern emerged. We also saw in the bar graphs that success rate varied across mouse_name. So we add these as predictors: the time information attached to spks, and mouse_name. We keep contrast_left, contrast_right, and average_spks, and add the 40 per-time-bin average spike counts, brain_area, and mouse_name. This raises accuracy slightly, to 70.0%.

## GLM Accuracy with Time Bins, Spikes, Brain Area, & Mouse Name: 0.7001

If the 40 time-bin features are noisy or highly correlated, it might be better to aggregate them into fewer features: an early, a middle, and a late average. However, doing this decreased accuracy slightly, to 69.9%.

## GLM Accuracy with Aggregated Time Features (Early, Mid, Late): 0.699
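The aggregation step can be sketched for a single trial. The split into bins 1-15, 16-30, and 31-40 follows the appendix code; `aggregate_time_bins` is a hypothetical name and the input is a toy vector of 40 per-bin means.

```r
# Collapse a trial's 40 per-bin mean spike counts into three features,
# using the early / mid / late split from the appendix
aggregate_time_bins <- function(spike_means) {
  stopifnot(length(spike_means) == 40)
  c(early = mean(spike_means[1:15]),
    mid   = mean(spike_means[16:30]),
    late  = mean(spike_means[31:40]))
}

aggregate_time_bins(rep(c(0.01, 0.02, 0.04), times = c(15, 15, 10)))
# early = 0.01, mid = 0.02, late = 0.04
```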

Choosing Significant Predictor Variables

We can combine these two models so that both the time-bin features and the categorical time variable (time_category: early, mid, or late) are included, and then use LASSO regularization to select the most useful variables. This improves accuracy to 70.8%, up from 70.0%. It also suggests that some variables add noise rather than predictive signal, which is important to keep in mind. Below we can see which variables LASSO kept.

## Optimal Lambda: 0.000600118
## Selected predictors by LASSO:
##  [1] "contrast_right"      "mouse_nameForssmann" "mouse_nameLederberg"
##  [4] "brain_areaLSr"       "brain_areaMB"        "brain_areaMG"       
##  [7] "brain_areaMOp"       "brain_areaMOs"       "brain_areaORBm"     
## [10] "brain_areaPL"        "brain_areaPO"        "brain_arearoot"     
## [13] "brain_areaTH"        "brain_areaVISp"      "brain_areaVISpm"    
## [16] "brain_areaVPM"       "brain_areaZI"        "success_rate"       
## [19] "time_bin_1"          "time_bin_2"          "time_bin_3"         
## [22] "time_bin_4"          "time_bin_5"          "time_bin_7"         
## [25] "time_bin_8"          "time_bin_9"          "time_bin_10"        
## [28] "time_bin_11"         "time_bin_12"         "time_bin_13"        
## [31] "time_bin_14"         "time_bin_15"         "time_bin_16"        
## [34] "time_bin_19"         "time_bin_20"         "time_bin_21"        
## [37] "time_bin_22"         "time_bin_23"         "time_bin_24"        
## [40] "time_bin_25"         "time_bin_26"         "time_bin_27"        
## [43] "time_bin_28"         "time_bin_29"         "time_bin_30"        
## [46] "time_bin_31"         "time_bin_32"         "time_bin_33"        
## [49] "time_bin_35"         "time_bin_36"         "time_bin_37"        
## [52] "time_bin_38"         "time_bin_39"         "time_bin_40"
## LASSO Model Accuracy: 0.708
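For readers unfamiliar with the LASSO step, here is a minimal sketch with glmnet (loaded in the appendix) on simulated data, not on the session data. In the simulation only x1 affects the outcome, so it should survive the penalty; the data, seed, and variable names are all made up for illustration.

```r
library(glmnet)

set.seed(141)
n <- 300
# Simulated data: five numeric predictors, but only x1 affects the outcome
X <- matrix(rnorm(n * 5), ncol = 5, dimnames = list(NULL, paste0("x", 1:5)))
y <- rbinom(n, 1, plogis(2 * X[, "x1"]))

# alpha = 1 selects the LASSO penalty; cv.glmnet chooses lambda by 10-fold CV
cv_fit <- cv.glmnet(X, y, family = "binomial", alpha = 1)
cv_fit$lambda.min

# Predictors whose coefficients are non-zero at lambda.min
cf <- as.matrix(coef(cv_fit, s = "lambda.min"))
selected <- setdiff(rownames(cf)[cf[, 1] != 0], "(Intercept)")
selected  # x1 should appear; some noise predictors may be kept as well
```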

Optimizing Model Performance

We can use XGBoost instead of a GLM to see whether it makes a difference. XGBoost is a machine learning algorithm that applies gradient boosting to decision trees. It also has its own built-in regularization, which may help here, since LASSO showed that several variables could be dropped. In this code we use the predictors contrast_left, contrast_right, brain_area, average_spks, mouse_name, and the categorical time bins. We create model matrices with model.matrix() (which converts categorical variables into dummy variables), train an XGBoost classifier, and compute test accuracy. Accuracy does indeed improve, to 72.3%.

## XGBoost Model Accuracy: 0.723
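The dummy-variable step can be seen on a toy data frame. This sketch shows only the model.matrix() encoding, with made-up average_spks values; the resulting numeric matrix is the kind of input XGBoost expects. The column names follow the same pattern as the LASSO output above (for example, mouse_nameForssmann).

```r
# Toy data frame with one categorical and one numeric predictor
toy <- data.frame(
  mouse_name   = factor(c("Cori", "Forssmann", "Hench", "Lederberg")),
  average_spks = c(0.03, 0.05, 0.04, 0.06)
)
# model.matrix() turns the factor into 0/1 columns; Cori is the baseline level
# absorbed into the intercept, and [, -1] drops the intercept column
X <- model.matrix(~ mouse_name + average_spks, data = toy)[, -1]
colnames(X)
# "mouse_nameForssmann" "mouse_nameHench" "mouse_nameLederberg" "average_spks"
```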

We can also build an ensemble model that combines the strengths of the GLM and XGBoost models. After tuning the weights, we get the highest accuracy when the XGBoost model is weighted at 0.7 and the GLM model at 0.3. This increases accuracy to 72.8%.

## Ensemble Model (XGBoost + GLM) Accuracy: 0.728
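A weighted ensemble of this kind can be sketched as a weighted average of the two models’ predicted success probabilities, thresholded at 0.5. The 0.7/0.3 default comes from the text; `ensemble_predict`, the probabilities, and the labels are hypothetical toy values (labels coded 0/1, as in final_data), and the real tuning code may differ.

```r
# Hypothetical sketch: weighted average of predicted probabilities of success,
# converted to a 0/1 prediction at a 0.5 threshold
ensemble_predict <- function(p_xgb, p_glm, w_xgb = 0.7, threshold = 0.5) {
  p <- w_xgb * p_xgb + (1 - w_xgb) * p_glm
  as.integer(p > threshold)
}

p_xgb <- c(0.80, 0.40, 0.55, 0.30)
p_glm <- c(0.60, 0.70, 0.20, 0.45)
ensemble_predict(p_xgb, p_glm)  # 1 0 0 0

# Tuning the weight: accuracy on held-out labels for a grid of weights
y_val <- c(1, 1, 0, 0)
weights <- seq(0, 1, by = 0.1)
acc <- sapply(weights, function(w) mean(ensemble_predict(p_xgb, p_glm, w_xgb = w) == y_val))
weights[which.max(acc)]
```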

Returning to the GLM, we can try adding interaction terms to capture the combined effects of more than one variable. With interaction terms, however, accuracy is 70.8%, no better than the LASSO model and below the ensemble, so the interaction terms do not appear to help. We leave them out and stick with the ensemble model.

## GLM Accuracy with Interaction Terms: 0.708
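In an R formula, `a * b` expands to the main effects plus their interaction `a:b`. This sketch uses simulated data with a placeholder outcome (no real signal) only to show which terms the GLM fits; the chosen predictors are illustrative, not necessarily the ones used above.

```r
set.seed(1)
toy <- data.frame(contrast_left = runif(200), contrast_right = runif(200))
toy$success <- rbinom(200, 1, 0.6)  # placeholder outcome, no real signal

# contrast_left * contrast_right = contrast_left + contrast_right + contrast_left:contrast_right
fit <- glm(success ~ contrast_left * contrast_right, family = binomial, data = toy)
names(coef(fit))
# "(Intercept)" "contrast_left" "contrast_right" "contrast_left:contrast_right"
```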

Performance on Test Set & Discussion

Here we apply the ensemble model to our test data. The test data consist of the last 100 trials from session 1 and the last 100 trials from session 18, which were not used when building the predictive models. The ensemble model achieves 72.5% accuracy on the test data, close to the 72.8% we observed during model building. This is only slightly better than the 71% of the simple model, so the model is decent rather than strong, even though we tried many combinations of variables to reach the highest accuracy we could. A model with accuracy above 85% would have been a good model. The issue was that accuracy seemed to plateau at around 72%, and none of the optimization methods we tried raised it further. Given more time, I would investigate more specific relationships between these features to raise model accuracy.

## Ensemble Model Accuracy on Combined Test Data: 0.725
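A hold-out split like the one described can be sketched as follows, assuming a trial-level data frame with a session id and a within-session trial index. `split_last_trials`, the column name `trial_idx`, and the toy data are hypothetical.

```r
# Hypothetical helper: hold out the last `n_test` trials of the chosen sessions;
# `trials` is assumed to have one row per trial with session_id and trial_idx
split_last_trials <- function(trials, sessions = c(1, 18), n_test = 100) {
  is_test <- rep(FALSE, nrow(trials))
  for (s in sessions) {
    idx <- which(trials$session_id == s)
    idx <- idx[order(trials$trial_idx[idx])]  # chronological order within session
    is_test[tail(idx, n_test)] <- TRUE
  }
  list(train = trials[!is_test, , drop = FALSE],
       test  = trials[is_test, , drop = FALSE])
}

toy_trials <- data.frame(session_id = rep(c(1, 2, 18), each = 150),
                         trial_idx  = rep(1:150, 3))
parts <- split_last_trials(toy_trials)
nrow(parts$test)   # 200
nrow(parts$train)  # 250
```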

Acknowledgements

This project was done as a course project for STA141A (Fundamentals of Statistical Data Science) at University of California, Davis taught by Dr. Shizhe Chen during Winter 2025.
https://chatgpt.com/share/67d87728-616c-8011-9c37-ab8108718593
https://chatgpt.com/share/67d87747-db70-8011-b17c-cff7a31ae529
https://chatgpt.com/share/67d87423-055c-8011-aace-ed123da326b5
*One of the chats used focusing on EDA included pdf files, and sharing these chatgpt conversations is not yet supported. The chat history can be provided on request.

Code Appendix

# load libraries
library(knitr)
library(tidyverse)
library(purrr)
library(plotly)
library(DT)
library(glmnet)
library(xgboost)
library('class')

# set the working directory for all chunks
opts_knit$set(root.dir = "/Users/tatianacosta/STA141AProject")


# collect the code for our appendix
opts_chunk$set(echo = FALSE, warning = FALSE, cache = TRUE)
input_file = current_input()
if (is.null(input_file)) {
  input_file = "STA_141A_Final_Project.Rmd"
}
purl(input = input_file, output = "all_code.R", documentation = 0)

# basic raw data setup
session=list()
for(i in 1:18){
  session[[i]]=readRDS(paste('./Data/session',i,'.rds',sep=''))
}

# basic examination of data
summary(session[[1]])

## str(session[[1]])

# matrix with all contrast_left tables
contrast_left_all = do.call(rbind, lapply(1:18, function(i) {
  table(session[[i]]$contrast_left)
}))
rownames(contrast_left_all) = paste("Session", 1:18)
datatable(contrast_left_all, options = list(pageLength = 10))

# matrix with all contrast_right tables
contrast_right_all = do.call(rbind, lapply(1:18, function(i) {
  table(session[[i]]$contrast_right)
}))
rownames(contrast_right_all) = paste("Session", 1:18)
datatable(contrast_right_all, options = list(pageLength = 10))

# matrix with all feedback_type tables
feedback_type_all = do.call(rbind, lapply(1:18, function(i) {
  table(session[[i]]$feedback_type)
}))
rownames(feedback_type_all) = paste("Session", 1:18)
datatable(feedback_type_all, options = list(pageLength = 10))

# matrix with all mouse_names
mouse_names_all = sapply(1:18, function(i) {
  unique(session[[i]]$mouse_name)  
})
mouse_names_df = data.frame(Session = paste("Session", 1:18), 'Mouse Name' = mouse_names_all)
datatable(mouse_names_df, options = list(pageLength = 10))

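# matrix with all brain_area tables: one row per session, one column per brain area
# (areas not recorded in a session get a count of 0)
brain_area_long = do.call(rbind, lapply(1:18, function(i) {
  data.frame(Session = paste("Session", i), brain_area = session[[i]]$brain_area)
}))
brain_area_all = as.data.frame.matrix(table(
  factor(brain_area_long$Session, levels = paste("Session", 1:18)),
  brain_area_long$brain_area
))
datatable(brain_area_all, options = list(pageLength = 10, scrollX = TRUE,
    autoWidth = TRUE))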

# matrix with all date_exp tables
date_exp_all = sapply(1:18, function(i) {
  unique(session[[i]]$date_exp)  
})
date_exp_df = data.frame(Session = paste("Session", 1:18), Date_Exp = date_exp_all)
datatable(date_exp_df, options = list(pageLength = 10))

n.session = 18
# Create tibble for meta data
meta = tibble(
  mouse_name = rep('name', n.session),
  n_brain_area = rep(0, n.session),
  n_neurons = rep(0, n.session),
  n_trials = rep(0, n.session),
  success_rate = rep(0, n.session)
)

for (i in 1:n.session) {
  tmp = session[[i]]
  meta$mouse_name[i] = tmp$mouse_name
  meta$n_brain_area[i] = length(unique(tmp$brain_area))
  meta$n_neurons[i] = dim(tmp$spks[[1]])[1]
  meta$n_trials[i] = length(tmp$feedback_type)
  meta$success_rate[i] = round(mean(tmp$feedback_type + 1) / 2, 2)
}

datatable(meta)

# calculate average spikes per trial and create data table
final_data = tibble()

for (i in 1:n.session) {
  tmp = session[[i]]
  
  # Average spikes per trial: total spikes of each neuron (summed over time bins),
  # averaged across neurons, computed separately for every trial in the session
  avg.spks = sapply(tmp$spks, function(trial) mean(rowSums(trial)))
  
  # Create one row per trial for this session
  session_data = tibble(
    session_id = i,
    mouse_name = tmp$mouse_name,
    trial_id = 1:length(tmp$feedback_type),
    contrast_left = tmp$contrast_left,
    contrast_right = tmp$contrast_right,
    average_spks = round(avg.spks, 2), 
    feedback_type = tmp$feedback_type
  )
  
  # Combine the current session data with the final data
  final_data = bind_rows(final_data, session_data)
}

datatable(final_data, options = list(pageLength = 10, scrollX = TRUE,
    autoWidth = TRUE))

# calculate average spikes per trial plus contrast features and create data table
final_data1 = tibble()

for (i in 1:n.session) {
  tmp = session[[i]]
  
  # Average spikes per trial: total spikes of each neuron (summed over time bins),
  # averaged across neurons, computed separately for every trial in the session
  avg.spks = sapply(tmp$spks, function(trial) mean(rowSums(trial)))
  
  # Create one row per trial for this session, with additional columns
  session_data = tibble(
    session_id = i,
    mouse_name = tmp$mouse_name,
    trial_id = 1:length(tmp$feedback_type),
    contrast_left = tmp$contrast_left,
    contrast_right = tmp$contrast_right,
    contrast_diff = abs(tmp$contrast_left - tmp$contrast_right),
    contrast_relationship = dplyr::case_when(
      tmp$contrast_left > tmp$contrast_right ~ "Left > Right",
      tmp$contrast_left < tmp$contrast_right ~ "Right > Left",
      tmp$contrast_left == tmp$contrast_right & tmp$contrast_left == 0 ~ "Both Zero",
      tmp$contrast_left == tmp$contrast_right ~ "Equal Non-Zero"
    ),
    average_spks = round(avg.spks, 2),
    feedback_type = tmp$feedback_type,
    success_rate = round(mean(tmp$feedback_type + 1) / 2, 2)
  )
  
  # Append this session's rows to final_data1
  final_data1 = bind_rows(final_data1, session_data)
}

datatable(final_data1, options = list(pageLength = 10, scrollX = TRUE,
    autoWidth = TRUE))


# Helper function: compute the mode (most frequent value)
mode_function <- function(x) {
  ux <- unique(x)
  ux[which.max(tabulate(match(x, ux)))]
}

# 1. Extract features per trial from all sessions and integrate into final_data
final_data <- map_dfr(seq_along(session), function(i) {
  sess <- session[[i]]
  n_trials <- length(sess$spks)
  # Compute session-level success_rate (convert -1 to 0)
  sr <- mean(ifelse(sess$feedback_type == 1, 1, 0))
  # Compute the mode for brain_area
  mode_brain <- { 
    ux <- unique(sess$brain_area)
    ux[which.max(tabulate(match(sess$brain_area, ux)))]
  }
  
  map_dfr(seq_len(n_trials), function(j) {
    trial_mat <- sess$spks[[j]]
    # Compute spike features: mean spike count per time bin (assumed 40 columns)
    spike_means <- colMeans(trial_mat, na.rm = TRUE)
    avg_spks <- mean(trial_mat, na.rm = TRUE)
    # Aggregated averages for time bins
    early_avg <- mean(spike_means[1:15], na.rm = TRUE)
    mid_avg   <- mean(spike_means[16:30], na.rm = TRUE)
    late_avg  <- mean(spike_means[31:40], na.rm = TRUE)
    # Determine time_category based on highest average
    agg <- c(early = early_avg, mid = mid_avg, late = late_avg)
    time_category <- names(agg)[which.max(agg)]
    
    # Extract time features from the corresponding time matrix
    time_mat <- sess$time[[j]]
    if (is.null(dim(time_mat)) || length(dim(time_mat)) < 2) {
      time_mat <- matrix(time_mat, nrow = 1)
    }
    time_means <- colMeans(time_mat, na.rm = TRUE)
    
    # Create a unique trial_id: "mouseName_sessionID_trialIdx"
    trial_id <- paste0(sess$mouse_name, "_", i, "_", j)
    
    data.frame(
      trial_id = trial_id,
      feedback_type = ifelse(sess$feedback_type[j] == -1, 0, sess$feedback_type[j]),
      contrast_left = sess$contrast_left[j],
      contrast_right = sess$contrast_right[j],
      mouse_name = sess$mouse_name,
      average_spks = avg_spks,
      brain_area = mode_brain,
      success_rate = sr,
      early_avg = early_avg,
      mid_avg = mid_avg,
      late_avg = late_avg,
      time_category = factor(time_category, levels = c("early", "mid", "late"))
    ) %>% bind_cols(
      as.data.frame(as.list(spike_means)) %>% setNames(paste0("time_bin_", seq_along(spike_means))),
      as.data.frame(as.list(time_means)) %>% setNames(paste0("time_", seq_along(time_means)))
    )
  })
})

# Remove any remaining NA values from final_data
final_data <- final_data %>% drop_na()


plot_spks = function(session_data) {
  session_number = which(sapply(session, identical, session_data))

  time_summary = do.call(rbind, lapply(session_data$spks, function(trial) 
    data.frame(
      time_bin = 1:ncol(trial),
      brain_area = rep(session_data$brain_area, each = ncol(trial)),
      # transpose so values run neuron by neuron, matching time_bin and brain_area
      spikes = as.vector(t(trial))
    )
  )) %>%
    group_by(time_bin, brain_area) %>%
    summarize(mean_spikes = mean(spikes, na.rm = TRUE), .groups = 'drop')

  suppressWarnings(
    plot_ly(time_summary, x = ~time_bin, y = ~mean_spikes, color = ~brain_area, 
            type = 'scatter', mode = 'lines', line = list(width = 2)) %>%
      layout(
        title = paste("Spike Activity Over Time by Brain Area (Session", session_number, ")"), 
        xaxis = list(title = "Time Bin"), 
        yaxis = list(title = "Mean Spike Count"), 
        legend = list(title = list(text = "Brain Area"))
      )
  )
}

plot_spks(session[[1]])

plot_spks(session[[2]])

plot_spks(session[[3]])

# plotting spks count over time for each brain area across all sessions for one mouse

get_mouse_sessions = function() {
  split(seq_along(session), sapply(session, function(s) unique(s$mouse_name)))
}

# Function to plot spike count over time for a specific mouse
plot_spks_mouse = function(mouse_name) {
  session_indices = get_mouse_sessions()[[mouse_name]]

  time_summary = do.call(rbind, lapply(session_indices, function(i) 
    do.call(rbind, lapply(session[[i]]$spks, function(trial)
      data.frame(
        time_bin = 1:ncol(trial),
        brain_area = rep(session[[i]]$brain_area, each = ncol(trial)),
        # transpose so values run neuron by neuron, matching time_bin and brain_area
        spikes = as.vector(t(trial)),
        session_id = paste("Session", i)
      )
    ))
  )) %>%
    group_by(time_bin, brain_area) %>%
    summarize(mean_spikes = mean(spikes, na.rm = TRUE), .groups = 'drop')

  plot_ly(time_summary, x = ~time_bin, y = ~mean_spikes, color = ~brain_area, 
          type = 'scatter', mode = 'lines', 
          line = list(width = 2), 
          colors = RColorBrewer::brewer.pal(8, "Set2")) %>%  
    layout(title = paste("Spike Activity Over Time by Brain Area for", mouse_name),
           xaxis = list(title = "Time Bin"),
           yaxis = list(title = "Mean Spike Count"),
           legend = list(title = list(text = "Brain Area")))
}

plot_spks_mouse("Cori")

# Spike activity over time: one panel per brain area, one line per session

plot_spks_mouse_seperate = function(mouse_name) {
  session_indices = get_mouse_sessions()[[mouse_name]]

  time_summary = do.call(rbind, lapply(session_indices, function(i) {
    session_data = session[[i]]
    do.call(rbind, lapply(session_data$spks, function(trial) {
      brain_areas = session_data$brain_area
      data.frame(
        time_bin = rep(seq_len(ncol(trial)), times = nrow(trial)),
        brain_area = rep(brain_areas, each = ncol(trial)),
        spikes = as.vector(t(trial)),  # transpose: R flattens matrices column-major
        session_id = paste("Session", i)
      )
    }))
  })) %>%
    group_by(time_bin, brain_area, session_id) %>%
    summarize(mean_spikes = mean(spikes, na.rm = TRUE), .groups = 'drop')

  ggplot(time_summary, aes(x = time_bin, y = mean_spikes, color = session_id)) +
    geom_line(alpha = 0.5) + geom_smooth() +
    theme_minimal() +
    labs(
      title = paste("Spike Activity Over Time by Brain Area for", mouse_name),
      x = "Time Bin",
      y = "Mean Spike Count",
      color = "Session"
    ) +
    facet_wrap(~brain_area, scales = "free_y")
}
plot_spks_mouse_seperate("Cori")

plot_spks_mouse("Lederberg")

plot_spks_mouse_seperate("Lederberg")

plot_spks_significant = function(mouse_name, significance_threshold = 5) {
  session_indices = get_mouse_sessions()[[mouse_name]]

  # Extract spike data across sessions
  all_data = do.call(rbind, lapply(session_indices, function(i) {
    session_data = session[[i]]
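    do.call(rbind, lapply(seq_along(session_data$spks), function(trial_idx) {
      trial = session_data$spks[[trial_idx]]
      # Sum spikes over the neurons of each brain area: rows = areas, cols = time bins
      area_counts = rowsum(trial, session_data$brain_area, na.rm = TRUE)
      data.frame(
        time_bin = rep(seq_len(ncol(trial)), times = nrow(area_counts)),
        brain_area = rep(rownames(area_counts), each = ncol(trial)),
        spikes = as.vector(t(area_counts)),  # transpose: R flattens matrices column-major
        feedback_type = session_data$feedback_type[trial_idx]
      )
    }))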
  }))

  # Identify significant brain areas
  significant_areas = all_data %>%
    group_by(brain_area) %>%
    summarize(mean_spikes = mean(spikes, na.rm = TRUE), .groups = 'drop') %>%
    filter(mean_spikes > significance_threshold) %>%
    pull(brain_area)

  # Exit early if no significant brain areas
  if (length(significant_areas) == 0) return(NULL)

  # Compute mean spikes across significant brain areas
  time_summary = all_data %>%
    filter(brain_area %in% significant_areas) %>%
    group_by(time_bin, feedback_type) %>%
    summarize(mean_spikes = mean(spikes, na.rm = TRUE), .groups = 'drop') %>%
    mutate(feedback_type = factor(feedback_type, levels = c(-1, 1), labels = c("Failure", "Success")))

  # Exit early if no data left to plot
  if (nrow(time_summary) == 0) return(NULL)

  # Create an interactive plot with plotly
  plot_ly(time_summary, x = ~time_bin, y = ~mean_spikes, color = ~feedback_type, 
          type = 'scatter', mode = 'lines', line = list(width = 2)) %>%
    layout(
      title = paste("Mean Spike Count Over Time (Significant Brain Areas) for", mouse_name),
      xaxis = list(title = "Time Bin"),
      yaxis = list(title = "Mean Spike Count"),
      legend = list(title = list(text = "Feedback Type"))
    )
}
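
# Sketch (toy numbers): rowsum() sums the rows of a neurons x time-bins matrix
# within each group, here brain area, giving one population count per area and
# bin. Averaging those and keeping areas above a cut-off mirrors the
# "significant" areas idea in plot_spks_significant(), which is a mean-activity
# threshold rather than a statistical test. The cut-off of 2 here is arbitrary.
toy_trial = matrix(c(1, 0, 2,     # neuron 1, area "VISp"
                     0, 3, 0,     # neuron 2, area "CA1"
                     4, 4, 4),    # neuron 3, area "CA1"
                   nrow = 3, byrow = TRUE)
toy_area = c("VISp", "CA1", "CA1")

toy_area_counts = rowsum(toy_trial, toy_area)   # rows "CA1", "VISp" (sorted)
toy_area_means = rowMeans(toy_area_counts)       # CA1 = 5, VISp = 1
toy_significant = names(toy_area_means)[toy_area_means > 2]
toy_significant                                  # "CA1"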

plot_spks_significant('Cori')

plot_spks_significant('Forssmann')

plot_spks_significant('Hench')

plot_spks_significant('Lederberg')

# We can also analyze success rates across sessions:
# calculate success rates by contrast relationship for each session
success_rates = lapply(seq_along(session), function(i) {
  session_data = session[[i]]
  session_data = session_data[c("contrast_left", "contrast_right", "feedback_type")]
  session_data = as.data.frame(session_data)
  
  session_data$contrast_relationship = dplyr::case_when(
    session_data$contrast_left > session_data$contrast_right ~ "Left > Right",
    session_data$contrast_left < session_data$contrast_right ~ "Right > Left",
    session_data$contrast_left == session_data$contrast_right & session_data$contrast_left == 0 ~ "Both Zero",
    session_data$contrast_left == session_data$contrast_right ~ "Equal Non-Zero"
  )
  
  success_rate_summary = session_data %>%
    group_by(contrast_relationship) %>%
    summarize(
      success_rate = mean(feedback_type == 1, na.rm = TRUE),
      .groups = 'drop'
    )
  
  success_rate_summary$session_id = i
  
  return(success_rate_summary)
})

success_rates_df = do.call(rbind, success_rates)
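
# Sketch (toy trials, not study data): the contrast-relationship categories and
# the per-category success rate in base R. feedback_type is coded 1 = success
# and -1 = failure, as in the session data.
toy_contrast_relationship = function(left, right) {
  ifelse(left > right, "Left > Right",
  ifelse(left < right, "Right > Left",
  ifelse(left == 0, "Both Zero", "Equal Non-Zero")))
}
toy_trials = data.frame(
  contrast_left  = c(1, 0.5, 0, 0,  0.25, 0.25),
  contrast_right = c(0, 1,   0, 0,  0.25, 0),
  feedback_type  = c(1, -1,  1, -1, 1,    1)
)
toy_trials$relationship = toy_contrast_relationship(toy_trials$contrast_left,
                                                    toy_trials$contrast_right)
toy_rates = tapply(toy_trials$feedback_type == 1, toy_trials$relationship, mean)
toy_rates   # Both Zero 0.5, Equal Non-Zero 1, Left > Right 1, Right > Left 0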

# Custom color scale to handle more than 8 unique session_id values
custom_colors <- RColorBrewer::brewer.pal(min(12, length(unique(success_rates_df$session_id))), "Set3")

# Now create an interactive plotly bar chart
plot <- suppressWarnings({
  plot_ly(success_rates_df, x = ~contrast_relationship, y = ~success_rate, 
          color = ~factor(session_id), type = 'bar', 
          text = ~paste("Success Rate: ", scales::percent(success_rate, accuracy = 0.1)),
          hoverinfo = 'text', 
          colors = custom_colors) %>%
    layout(title = "Success Rates by Contrast Relationship Across Sessions",
           xaxis = list(title = "Contrast Relationship"),
           yaxis = list(title = "Success Rate", tickformat = '%'),
           legend = list(title = list(text = "Session ID")),
           barmode = 'group')  # side-by-side bars (valid modes: 'group', 'stack', 'overlay', 'relative')
})

# Show the plot
plot


# Success rate by contrast relationship, per mouse.
# First compute per-session success rates, tagging each with its mouse name
success_rates = lapply(seq_along(session), function(i) {
  session_data = session[[i]]  # Extract session data
  
  # Extract mouse name dynamically
  mouse_name = unique(session_data$mouse_name)
  
  # Select relevant columns and convert to data.frame
  session_data = as.data.frame(session_data[c("contrast_left", "contrast_right", "feedback_type")])
  
  # Define contrast relationship categories
  session_data$contrast_relationship = dplyr::case_when(
    session_data$contrast_left > session_data$contrast_right ~ "Left > Right",
    session_data$contrast_left < session_data$contrast_right ~ "Right > Left",
    session_data$contrast_left == session_data$contrast_right & session_data$contrast_left == 0 ~ "Both Zero",
    session_data$contrast_left == session_data$contrast_right ~ "Equal Non-Zero"
  )
  
  # Compute success rates for each contrast relationship (as a fraction between 0 and 1)
  success_rate_summary = session_data %>%
    group_by(contrast_relationship) %>%
    summarize(
      success_rate = mean(feedback_type == 1, na.rm = TRUE),
      .groups = 'drop'
    )
  
  # Add mouse name to the dataset
  success_rate_summary$mouse_name = mouse_name
  
  return(success_rate_summary)
})

# Combine data from all sessions
success_rates_df = do.call(rbind, success_rates)

# Each mouse has several sessions, so average the session-level rates for each
# mouse and contrast relationship (every session weighted equally):
aggregated_df = success_rates_df %>% 
  group_by(mouse_name, contrast_relationship) %>% 
  summarize(success_rate = mean(success_rate), .groups = 'drop')
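
# Worked example (made-up counts): averaging session-level rates weights every
# session equally, while pooling trials weights every trial equally. The two
# can differ noticeably when session sizes differ.
toy_n_trials  = c(10, 90)    # two sessions of one hypothetical mouse
toy_n_success = c(9, 45)
toy_mean_of_rates = mean(toy_n_success / toy_n_trials)       # (0.9 + 0.5) / 2 = 0.7
toy_pooled_rate   = sum(toy_n_success) / sum(toy_n_trials)   # 54 / 100 = 0.54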

# Create the interactive plot with plotly, suppressing warnings
suppressWarnings({
  p <- plot_ly(
    data = aggregated_df,
    x = ~contrast_relationship,
    y = ~success_rate,
    color = ~mouse_name,
    type = 'bar'
  ) %>%
    layout(
      title = "Success Rates by Contrast Relationship Across Mice",
      xaxis = list(title = "Contrast Relationship", tickangle = 45),
      yaxis = list(title = "Success Rate", tickformat = ".0%"),
      barmode = "group"  # Grouped bars (dodge)
    )
})

p


# Select relevant predictors
model_data <- final_data %>%
  select(feedback_type, contrast_left, contrast_right, average_spks)

# Split into train/test (80% training)
set.seed(3)
split_index <- sample(seq_len(nrow(model_data)), size = 0.8 * nrow(model_data))
train_data <- model_data[split_index, ]
test_data  <- model_data[-split_index, ]


# Train logistic regression model
model_logit <- glm(feedback_type ~ contrast_left + contrast_right + average_spks, 
                   data = train_data, family = binomial)

# Predict & evaluate accuracy
pred_class_logit <- ifelse(predict(model_logit, test_data, type = "response") > 0.5, 1, 0)
accuracy_logit <- mean(pred_class_logit == test_data$feedback_type)
cat("Logistic Regression Accuracy:", round(accuracy_logit, 3), "\n")
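
# Sketch on R's built-in mtcars data (unrelated to the study): for a binomial
# glm, type = "response" is the inverse logit of the linear predictor, and the
# 0.5 cut-off turns probabilities into classes. Comparing accuracy with the
# majority-class rate shows whether a model beats always guessing the most
# common outcome.
toy_fit <- glm(vs ~ mpg, data = mtcars, family = binomial)
toy_prob <- predict(toy_fit, type = "response")
toy_prob_manual <- plogis(predict(toy_fit, type = "link"))   # same values
toy_class <- as.integer(toy_prob > 0.5)
toy_accuracy <- mean(toy_class == mtcars$vs)
toy_majority_baseline <- max(mean(mtcars$vs), 1 - mean(mtcars$vs))
table(predicted = factor(toy_class, levels = 0:1), observed = mtcars$vs)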


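# Min-max scale the numeric predictors using the training ranges only, so both
# sets share one scale (test values may then fall slightly outside [0, 1])
num_vars <- c("contrast_left", "contrast_right", "average_spks")
train_ranges <- lapply(train_data[num_vars], range)
scale_to_train <- function(df) {
  for (v in num_vars) df[[v]] <- (df[[v]] - train_ranges[[v]][1]) / diff(train_ranges[[v]])
  df
}
train_data_norm <- scale_to_train(train_data)
test_data_norm  <- scale_to_train(test_data)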

# Train k-NN model (k = 5)
preds_knn <- knn(train = train_data_norm[, -1], test = test_data_norm[, -1], 
                 cl = train_data$feedback_type, k = 5)

# Evaluate accuracy
accuracy_knn <- mean(preds_knn == test_data$feedback_type)
cat("k-NN Accuracy:", round(accuracy_knn, 3), "\n")
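
# Sketch on R's built-in iris data (unrelated to the study): k = 5 above is a
# fixed choice. class::knn.cv() returns leave-one-out predictions on the
# training set, so several k values can be compared without touching the test set.
library(class)
set.seed(1)   # knn breaks distance ties at random
toy_x  <- scale(iris[, 1:4])   # put the four features on a common scale
toy_cl <- iris$Species
toy_ks <- c(1, 3, 5, 7, 9, 15)
toy_loocv_acc <- sapply(toy_ks, function(k) mean(knn.cv(train = toy_x, cl = toy_cl, k = k) == toy_cl))
toy_best_k <- toy_ks[which.max(toy_loocv_acc)]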


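# Simple baseline: predict success whenever success_rate exceeds 0.5.
# If success_rate is derived from feedback_type (e.g. a per-session success
# rate), it partly encodes the outcome, so treat this as an optimistic reference.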
simple_preds <- ifelse(final_data$success_rate > 0.5, 1, 0)

# Calculate accuracy
accuracy_simple <- mean(simple_preds == final_data$feedback_type)
cat("Simple Model Accuracy (Based on Success Rate):", round(accuracy_simple, 4), "\n")


# Train-Test Split (80% training, 20% testing)
set.seed(3)
split_index <- sample(seq_len(nrow(final_data)), size = 0.8 * nrow(final_data))
train_data <- final_data[split_index, ]
test_data  <- final_data[-split_index, ]

# Convert categorical variables to factors
train_data <- train_data %>% mutate(mouse_name = as.factor(mouse_name), brain_area = as.factor(brain_area))
test_data  <- test_data %>% mutate(mouse_name = as.factor(mouse_name), brain_area = as.factor(brain_area))

# Define predictor variables dynamically
time_bin_vars <- paste0("time_bin_", 1:40)  # 40 time bins
predictor_vars <- c("contrast_left", "contrast_right", "mouse_name", "brain_area", "average_spks", time_bin_vars)

# Build formula dynamically
formula_str <- paste("feedback_type ~", paste(predictor_vars, collapse = " + "))
model_formula <- as.formula(formula_str)

# Train GLM Model
model_glm <- glm(model_formula, data = train_data, family = binomial)

# Predict & evaluate accuracy
preds_glm_prob <- predict(model_glm, newdata = test_data, type = "response")
pred_class_glm <- ifelse(preds_glm_prob > 0.5, 1, 0)
accuracy_glm <- mean(pred_class_glm == test_data$feedback_type)

cat("GLM Accuracy with Time Bins, Spikes, Brain Area, & Mouse Name:", round(accuracy_glm, 3), "\n")
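
# Sketch (simulated data, not study data): when one predictor is an exact linear
# combination of others, glm() cannot estimate it and reports NA for that
# coefficient. This would happen if average_spks equals the mean of the time-bin
# columns (true of the test features built by load_test_data() further down,
# since mean(trial_mat) equals the mean of colMeans(trial_mat)).
set.seed(5)
toy_bins <- matrix(rpois(100 * 4, lambda = 2), ncol = 4,
                   dimnames = list(NULL, paste0("time_bin_", 1:4)))
toy_df <- data.frame(toy_bins, average_spks = rowMeans(toy_bins))
toy_df$feedback_type <- rbinom(100, 1, 0.6)

# reformulate() builds the same kind of formula as paste() + as.formula()
toy_formula <- reformulate(c(paste0("time_bin_", 1:4), "average_spks"),
                           response = "feedback_type")
toy_fit <- glm(toy_formula, data = toy_df, family = binomial)
toy_aliased <- names(coef(toy_fit))[is.na(coef(toy_fit))]
toy_aliased   # "average_spks"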


# Train-Test Split (80% training, 20% testing)
set.seed(3)
split_index <- sample(seq_len(nrow(final_data)), size = 0.8 * nrow(final_data))
train_data <- final_data[split_index, ]
test_data  <- final_data[-split_index, ]

# Convert categorical variables to factors
train_data <- train_data %>% mutate(mouse_name = as.factor(mouse_name), brain_area = as.factor(brain_area))
test_data  <- test_data %>% mutate(mouse_name = as.factor(mouse_name), brain_area = as.factor(brain_area))

# Define predictor variables (Aggregated time features)
predictor_vars <- c("contrast_left", "contrast_right", "mouse_name", "brain_area", "success_rate", 
                    'average_spks', "early_avg", "mid_avg", "late_avg")

# Build formula dynamically
formula_str <- paste("feedback_type ~", paste(predictor_vars, collapse = " + "))
model_formula <- as.formula(formula_str)

# Train GLM Model
model_glm <- glm(model_formula, data = train_data, family = binomial)

# Predict & evaluate accuracy
preds_glm_prob <- predict(model_glm, newdata = test_data, type = "response")
pred_class_glm <- ifelse(preds_glm_prob > 0.5, 1, 0)
accuracy_glm <- mean(pred_class_glm == test_data$feedback_type)

cat("GLM Accuracy with Aggregated Time Features (Early, Mid, Late):", round(accuracy_glm, 3), "\n")
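
# Worked example (toy numbers): early/mid/late average bins 1-15, 16-30 and
# 31-40. With 40 equal-width bins, the mean over all bins is a fixed weighted
# combination of the three windows, (15 * early + 15 * mid + 10 * late) / 40,
# so if average_spks is computed from the same 40 bins it is redundant once
# the three window averages are in the model.
toy_bins <- rbind(c(rep(1, 15), rep(2, 15), rep(4, 10)),
                  c(rep(0, 15), rep(1, 15), rep(3, 10)))
toy_windows <- list(early_avg = 1:15, mid_avg = 16:30, late_avg = 31:40)
toy_window_means <- sapply(toy_windows, function(idx) rowMeans(toy_bins[, idx, drop = FALSE]))
toy_overall <- rowMeans(toy_bins)   # 2.125 and 1.125
toy_weighted <- drop(toy_window_means %*% c(15, 15, 10)) / 40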


# Train-Test Split (80% training, 20% testing)
set.seed(3)
split_index <- sample(seq_len(nrow(final_data)), size = 0.8 * nrow(final_data))
train_data <- final_data[split_index, ]
test_data  <- final_data[-split_index, ]

# Convert categorical variables to factors
train_data <- train_data %>% mutate(mouse_name = as.factor(mouse_name), brain_area = as.factor(brain_area))
test_data  <- test_data %>% mutate(mouse_name = as.factor(mouse_name), brain_area = as.factor(brain_area))

# Define predictor variables (all 40 time bins plus the aggregated early/mid/late features)
predictor_vars <- c("contrast_left", "contrast_right", "mouse_name", "brain_area", 
                    "success_rate", paste0("time_bin_", 1:40), "average_spks", "early_avg", "mid_avg", "late_avg")

# Convert data into model matrix (This automatically creates dummy variables for factors)
x_train <- model.matrix(~ . -1, data = train_data[, predictor_vars])  # Remove intercept
x_test  <- model.matrix(~ . -1, data = test_data[, predictor_vars])  

# Target variable (feedback_type is already binary: 0 or 1)
y_train <- train_data$feedback_type
y_test  <- test_data$feedback_type

# Cross-validation to find optimal lambda (Regularization strength)
cv_fit <- cv.glmnet(x_train, y_train, family = "binomial", alpha = 1)  # Alpha = 1 means LASSO
best_lambda <- cv_fit$lambda.min
cat("Optimal Lambda:", best_lambda, "\n")

# Train final LASSO model using best lambda
lasso_model <- glmnet(x_train, y_train, family = "binomial", alpha = 1, lambda = best_lambda)

# Extract nonzero coefficients (selected predictors)
lasso_coeffs <- coef(lasso_model)
selected_idx <- which(lasso_coeffs != 0)
selected_predictors <- rownames(lasso_coeffs)[selected_idx]
selected_predictors <- setdiff(selected_predictors, "(Intercept)")  # Remove intercept

cat("Selected predictors by LASSO:\n")
print(selected_predictors)

# Predict probabilities on test set
preds_prob_lasso <- predict(lasso_model, newx = x_test, type = "response")
pred_class_lasso <- ifelse(preds_prob_lasso > 0.5, 1, 0)

# Calculate accuracy
accuracy_lasso <- mean(pred_class_lasso == y_test)
cat("LASSO Model Accuracy:", round(accuracy_lasso, 3), "\n")
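
# Sketch (a simplified, swapped-in version of the technique, not glmnet's code):
# glmnet fits the LASSO by coordinate descent. For squared-error loss with
# standardized predictors, each update soft-thresholds the fit of one predictor
# to its partial residual, which is how coefficients become exactly zero.
# glmnet's binomial family wraps a weighted version of this in an outer
# reweighting loop; this sketch covers only the Gaussian case.
soft_threshold <- function(z, lambda) sign(z) * pmax(abs(z) - lambda, 0)

# Minimizes (1 / (2n)) * sum((y - X b)^2) + lambda * sum(abs(b))
lasso_coordinate_descent <- function(X, y, lambda, n_sweeps = 500) {
  n  <- nrow(X)
  Xs <- scale(X) * sqrt(n / (n - 1))   # columns with mean 0 and mean(x^2) = 1
  r  <- y - mean(y)                    # residual for b = 0 (y centred, no intercept)
  b  <- numeric(ncol(X))
  for (sweep in seq_len(n_sweeps)) {
    for (j in seq_along(b)) {
      rho   <- sum(Xs[, j] * r) / n + b[j]   # fit of column j to its partial residual
      b_new <- soft_threshold(rho, lambda)
      r     <- r - Xs[, j] * (b_new - b[j])  # keep the residual in sync
      b[j]  <- b_new
    }
  }
  b
}

set.seed(7)
toy_X <- matrix(rnorm(200 * 3), ncol = 3)
toy_y <- 2 * toy_X[, 1] - toy_X[, 2] + rnorm(200)
lasso_coordinate_descent(toy_X, toy_y, lambda = 0.5)   # the third coefficient is typically 0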


# Ensure categorical variables are factors
final_data <- final_data %>%
  mutate(
    brain_area = as.factor(brain_area),
    mouse_name = as.factor(mouse_name),
    early_avg = rowMeans(select(., paste0("time_bin_", 1:15)), na.rm = TRUE),
    mid_avg   = rowMeans(select(., paste0("time_bin_", 16:30)), na.rm = TRUE),
    late_avg  = rowMeans(select(., paste0("time_bin_", 31:40)), na.rm = TRUE)
  )

# Define predictors (EXCLUDING success_rate)
predictor_vars <- c("contrast_left", "contrast_right", "brain_area", "mouse_name",
                    "average_spks", "early_avg", "mid_avg", "late_avg")

# Train-test split
set.seed(123)
split_index <- sample(seq_len(nrow(final_data)), size = 0.8 * nrow(final_data))
train_data <- final_data[split_index, ]
test_data  <- final_data[-split_index, ]

# One-hot encode categorical variables. No min-max scaling: tree splits depend
# only on the ordering of each feature, and scaling train and test with their
# own ranges would map the same raw value to different scaled values
x_train <- model.matrix(~ . -1, data = train_data[, predictor_vars])
x_test  <- model.matrix(~ . -1, data = test_data[, predictor_vars])

# Convert to XGBoost format
dtrain <- xgb.DMatrix(data = x_train, label = train_data$feedback_type)
dtest  <- xgb.DMatrix(data = x_test, label = test_data$feedback_type)

# XGBoost hyperparameters
params <- list(
  objective = "binary:logistic",
  eval_metric = "error",
  max_depth = 4,
  eta = 0.1,
  subsample = 0.8,
  colsample_bytree = 0.8,
  lambda = 1
)

# Train the model
xgb_model <- xgb.train(params = params, data = dtrain, nrounds = 100, 
                       watchlist = list(train = dtrain, eval = dtest), verbose = 0)

# Predictions & Accuracy
preds_xgb_class <- ifelse(predict(xgb_model, dtest) > 0.5, 1, 0)
accuracy_xgb <- mean(preds_xgb_class == test_data$feedback_type)
cat("XGBoost Model Accuracy:", round(accuracy_xgb, 3), "\n")
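
# Sketch (simulated data): tree learners such as XGBoost choose splits by
# comparing feature values against thresholds, so only the ordering of each
# feature matters. A one-split "stump" chosen by misclassification error stands
# in for a tree here (XGBoost scores splits differently, but its score also
# depends only on which rows go left or right): any strictly increasing
# rescaling yields the same partition of the rows.
best_stump_partition <- function(x, y) {
  cuts <- sort(unique(x))
  errs <- sapply(cuts, function(cut) {
    pred <- as.integer(x > cut)
    min(mean(pred != y), mean((1 - pred) != y))   # allow either side to be "1"
  })
  x > cuts[which.min(errs)]   # logical vector: which rows fall right of the split
}

set.seed(11)
toy_x <- rnorm(100)
toy_y <- as.integer(toy_x + rnorm(100, sd = 0.5) > 0.3)
toy_part_raw    <- best_stump_partition(toy_x, toy_y)
toy_part_minmax <- best_stump_partition((toy_x - min(toy_x)) / diff(range(toy_x)), toy_y)
toy_part_exp    <- best_stump_partition(exp(toy_x), toy_y)
identical(toy_part_raw, toy_part_minmax) && identical(toy_part_raw, toy_part_exp)   # TRUE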


# Ensure categorical variables are factors
final_data$brain_area <- as.factor(final_data$brain_area)
final_data$mouse_name <- as.factor(final_data$mouse_name)

# Add aggregated time bins
final_data <- final_data %>%
  mutate(
    early_avg = rowMeans(select(., paste0("time_bin_", 1:15)), na.rm = TRUE),
    mid_avg   = rowMeans(select(., paste0("time_bin_", 16:30)), na.rm = TRUE),
    late_avg  = rowMeans(select(., paste0("time_bin_", 31:40)), na.rm = TRUE)
  )

# Select predictor variables (Removing success_rate)
predictor_vars <- c("contrast_left", "contrast_right", "brain_area", "mouse_name",
                    "average_spks", "early_avg", "mid_avg", "late_avg")

# Train-test split
set.seed(123)
split_index <- sample(seq_len(nrow(final_data)), size = 0.8 * nrow(final_data))
train_data <- final_data[split_index, ]
test_data  <- final_data[-split_index, ]

# One-hot encode categorical variables for XGBoost
x_train <- model.matrix(~ . -1, data = train_data[, predictor_vars])
x_test  <- model.matrix(~ . -1, data = test_data[, predictor_vars])

# Define target variable
y_train <- train_data$feedback_type
y_test  <- test_data$feedback_type

# Convert to DMatrix for XGBoost
dtrain <- xgb.DMatrix(data = x_train, label = y_train)
dtest  <- xgb.DMatrix(data = x_test, label = y_test)

# Set XGBoost parameters
params <- list(
  objective = "binary:logistic",
  eval_metric = "error",
  max_depth = 3,
  eta = 0.05,
  subsample = 0.7,
  colsample_bytree = 0.7,
  lambda = 1
)

# Train XGBoost model
xgb_model <- xgb.train(params = params, data = dtrain, nrounds = 50,
                       watchlist = list(train = dtrain, eval = dtest), verbose = 0)

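# Get XGBoost predictions for train and test sets, to use as a GLM feature.
# The training-set predictions are in-sample (the model has seen these rows), so
# they are more confident than on new data; out-of-fold predictions would keep
# the GLM from over-weighting pred_xgb.
train_data$pred_xgb <- predict(xgb_model, dtrain)
test_data$pred_xgb <- predict(xgb_model, dtest)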

# Ensure feedback_type is binary for GLM
train_data$feedback_binary <- ifelse(train_data$feedback_type == 1, 1, 0)
test_data$feedback_binary  <- ifelse(test_data$feedback_type == 1, 1, 0)

# Train a GLM Model using XGBoost predictions as an extra feature
glm_model <- glm(feedback_binary ~ contrast_left + contrast_right + brain_area + pred_xgb,
                 data = train_data, family = binomial)

# Predict on test set
glm_preds <- predict(glm_model, newdata = test_data, type = "response")
final_preds <- ifelse(glm_preds > 0.5, 1, 0)

# Calculate accuracy
#accuracy_ensemble <- mean(final_preds == test_data$feedback_type)
#cat("Ensemble Model Accuracy:", round(accuracy_ensemble, 3), "\n")


# Sanity checks before using the model and the test matrix
if (!exists("xgb_model")) stop("Error: `xgb_model` does not exist!")
if (nrow(x_test) == 0 || ncol(x_test) == 0) {
  stop("Error: `x_test` is empty or incorrectly formatted!")
}

# Ensure `x_test` matches `xgb_model`'s features
model_features <- xgb_model$feature_names
missing_cols <- setdiff(model_features, colnames(x_test))

# Add missing columns filled with zeros
for (col in missing_cols) {
  x_test <- cbind(x_test, rep(0, nrow(x_test)))
  colnames(x_test)[ncol(x_test)] <- col
}

# Reorder `x_test` to match training features
x_test <- x_test[, model_features, drop = FALSE]

# Convert `x_test` to XGBoost DMatrix
dtest <- xgb.DMatrix(data = x_test)

# Now make predictions safely
test_data$pred_xgb <- predict(xgb_model, dtest)

# Generate predictions
pred_glm <- predict(glm_model, newdata = test_data, type = "response")
pred_xgb <- predict(xgb_model, dtest)

# Compute ensemble prediction (weighted average)
ensemble_pred <- (0.7 * pred_xgb) + (0.3 * pred_glm)
final_ensemble_class <- ifelse(ensemble_pred > 0.5, 1, 0)

# Evaluate accuracy
accuracy_ensemble <- mean(final_ensemble_class == test_data$feedback_type)
cat("Ensemble Model (XGBoost + GLM) Accuracy:", round(accuracy_ensemble, 3), "\n")
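
# Sketch (simulated probabilities, hypothetical helper names): the 0.7 / 0.3
# weights above are fixed. One option is to pick the weight on a validation
# split held out from training, never on the final test set, for example by
# minimizing log loss. Worked example of the blend itself:
# 0.7 * 0.6 + 0.3 * 0.3 = 0.51, just above the 0.5 cut-off.
blend_probs <- function(p_a, p_b, w) w * p_a + (1 - w) * p_b
log_loss <- function(y, p, eps = 1e-15) {
  p <- pmin(pmax(p, eps), 1 - eps)   # avoid log(0)
  -mean(y * log(p) + (1 - y) * log(1 - p))
}

set.seed(13)
toy_y_val <- rbinom(300, 1, 0.6)
toy_signal <- ifelse(toy_y_val == 1, 1, -1)
toy_p_a <- plogis(1.5 * toy_signal + rnorm(300))   # stronger simulated model
toy_p_b <- plogis(0.5 * toy_signal + rnorm(300))   # weaker simulated model
toy_weights <- seq(0, 1, by = 0.1)
toy_val_loss <- sapply(toy_weights, function(w) log_loss(toy_y_val, blend_probs(toy_p_a, toy_p_b, w)))
toy_best_w <- toy_weights[which.min(toy_val_loss)]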


# Define new GLM model with interaction terms
glm_model_int <- glm(feedback_type ~ contrast_left * contrast_right + brain_area * mouse_name +
                     early_avg * mid_avg + late_avg * brain_area + pred_xgb, 
                     data = train_data, family = binomial)

# Predict & evaluate accuracy
glm_preds_int <- predict(glm_model_int, newdata = test_data, type = "response")
final_preds_int <- ifelse(glm_preds_int > 0.5, 1, 0)
accuracy_glm_int <- mean(final_preds_int == test_data$feedback_type)
cat("GLM Accuracy with Interaction Terms:", round(accuracy_glm_int, 3), "\n")
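
# Sketch (formula and toy factors only): in an R formula, a * b expands to
# a + b + a:b, so each product term above adds both main effects and their
# interaction; terms() shows the expansion. For two factors with A and B levels
# under R's default treatment coding, the interaction adds (A - 1) * (B - 1)
# columns, which grows quickly for terms such as brain_area * mouse_name.
toy_terms <- attr(terms(feedback_type ~ contrast_left * contrast_right + early_avg * mid_avg),
                  "term.labels")
toy_terms
# "contrast_left" "contrast_right" "early_avg" "mid_avg"
# "contrast_left:contrast_right" "early_avg:mid_avg"

toy_factors <- data.frame(f1 = factor(c("a", "b", "c", "a")),
                          f2 = factor(c("x", "y", "x", "y")))
colnames(model.matrix(~ f1 * f2, data = toy_factors))
# "(Intercept)" "f1b" "f1c" "f2y" "f1b:f2y" "f1c:f2y"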


# Function to load and preprocess test data
load_test_data <- function(file) {
  sess <- readRDS(file)
  n_trials <- length(sess$spks)

  # Compute mode of brain_area
  mode_brain <- {
    ux <- unique(sess$brain_area)
    ux[which.max(tabulate(match(sess$brain_area, ux)))]
  }

  # Extract features per trial
  test_data <- map_dfr(seq_len(n_trials), function(j) {
    trial_mat <- sess$spks[[j]]
    spike_means <- colMeans(trial_mat, na.rm = TRUE)
    avg_spks <- mean(trial_mat, na.rm = TRUE)

    # Aggregate spike counts over time bins
    early_avg <- mean(spike_means[1:15], na.rm = TRUE)
    mid_avg   <- mean(spike_means[16:30], na.rm = TRUE)
    late_avg  <- mean(spike_means[31:40], na.rm = TRUE)

    # Create a unique trial ID
    trial_id <- paste0(sess$mouse_name, "_", j)

    # Return processed data
    data.frame(
      trial_id = trial_id,
      feedback_type = ifelse(sess$feedback_type[j] == -1, 0, sess$feedback_type[j]),
      contrast_left = sess$contrast_left[j],
      contrast_right = sess$contrast_right[j],
      mouse_name = sess$mouse_name,
      average_spks = avg_spks,
      brain_area = mode_brain,
      early_avg = early_avg,
      mid_avg = mid_avg,
      late_avg = late_avg
    )
  })

  return(test_data)
}
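
# Sketch (toy values): two per-trial summaries used in load_test_data().
# 1) The mode via match() + tabulate(); ties go to the value seen first.
# 2) mean(m) equals mean(colMeans(m)) for a matrix with no NAs, because every
#    column has the same number of rows.
mode_value <- function(x) {
  ux <- unique(x)
  ux[which.max(tabulate(match(x, ux)))]
}
mode_value(c("CA1", "VISp", "CA1", "MOs"))   # "CA1"
mode_value(c("b", "a", "a", "b"))            # "b" (tie, first seen)

toy_trial <- matrix(c(0, 1, 2,
                      3, 0, 1), nrow = 2, byrow = TRUE)
c(mean(toy_trial), mean(colMeans(toy_trial)))   # both 7/6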

# Load and combine test1.rds and test2.rds
test1 <- load_test_data("test1.rds")
test2 <- load_test_data("test2.rds")
test_data <- bind_rows(test1, test2)

# Ensure categorical variables are treated as factors
test_data <- test_data %>%
  mutate(
    brain_area = as.factor(brain_area),
    mouse_name = as.factor(mouse_name)
  )

# Define predictor variables based on trained model
predictor_vars <- c("contrast_left", "contrast_right", "brain_area", "mouse_name",
                    "average_spks", "early_avg", "mid_avg", "late_avg")

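# Predictors for the XGBoost matrix, excluding mouse_name. The XGBoost model was
# trained with mouse_name dummies, so those columns are zero-filled below, which
# encodes every test trial as the reference (first) mouse level.
predictor_vars_xgb <- setdiff(predictor_vars, "mouse_name")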

# One-hot encode test data
x_test <- model.matrix(~ . -1, data = test_data[, predictor_vars_xgb])

# Extract feature names from the trained XGBoost model
# (feature_names is already a character vector, so use it directly)
model_features <- xgb_model$feature_names

# Find columns the model expects but the test matrix lacks
missing_cols <- setdiff(model_features, colnames(x_test))

# Add missing columns to test data (filled with 0s)
for (col in missing_cols) {
  x_test <- cbind(x_test, rep(0, nrow(x_test)))
  colnames(x_test)[ncol(x_test)] <- col
}

# Reorder `x_test` to match XGBoost model features
x_test <- x_test[, model_features, drop = FALSE]

# Convert to DMatrix for XGBoost
dtest <- xgb.DMatrix(data = x_test)

# Get XGBoost predictions (Assumes `xgb_model` is already in the environment)
test_data$pred_xgb <- predict(xgb_model, dtest)

# Get GLM predictions (Assumes `glm_model` is already in the environment)
glm_preds <- predict(glm_model, newdata = test_data, type = "response")

# Compute ensemble prediction (weighted average)
ensemble_pred <- (0.7 * test_data$pred_xgb) + (0.3 * glm_preds)
final_preds <- ifelse(ensemble_pred > 0.5, 1, 0)

# Compute accuracy
accuracy <- mean(final_preds == test_data$feedback_type)
cat("Ensemble Model Accuracy on Combined Test Data:", round(accuracy, 3), "\n")


# Read the tangled R script and display it inside a fenced code block
all_code = readLines("all_code.R")
cat("```r\n")
cat(paste(all_code, collapse = "\n"))
cat("\n```\n")